{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import re\n",
    "from transformers import BertTokenizerFast\n",
    "import copy\n",
    "import torch\n",
    "from common.utils import Preprocessor\n",
    "import yaml\n",
    "import logging\n",
    "from pprint import pprint\n",
    "from IPython.core.debugger import set_trace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No name passed -> this is the root logger; let every record from DEBUG upwards through.\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(\"DEBUG\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the libyaml-backed loader/dumper classes when PyYAML provides them;\n",
    "# otherwise fall back to the pure-Python ones.\n",
    "try:\n",
    "    from yaml import CLoader as Loader, CDumper as Dumper\n",
    "except ImportError:\n",
    "    from yaml import Loader, Dumper\n",
    "\n",
    "# Read the build configuration. The context manager closes the file handle\n",
    "# deterministically; the explicit encoding matches the other open() calls in this notebook.\n",
    "# NOTE(review): FullLoader is kept from the original; yaml.safe_load would be enough\n",
    "# if the config only contains plain scalars/lists/dicts -- confirm before switching.\n",
    "with open(\"build_data_config.yaml\", \"r\", encoding = \"utf-8\") as config_file:\n",
    "    config = yaml.load(config_file, Loader = yaml.FullLoader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input and output directories are both namespaced by the experiment name.\n",
    "exp_name = config[\"exp_name\"]\n",
    "data_in_dir = os.path.join(config[\"data_in_dir\"], exp_name)\n",
    "data_out_dir = os.path.join(config[\"data_out_dir\"], exp_name)\n",
    "# exist_ok avoids the check-then-create race and makes the cell safe to re-run.\n",
    "os.makedirs(data_out_dir, exist_ok = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect every *.json file under data_in_dir, keyed by its file name without the extension.\n",
    "file_name2data = {}\n",
    "for dir_path, _sub_dirs, files in os.walk(data_in_dir):\n",
    "    for file_name in files:\n",
    "        stem, ext = os.path.splitext(file_name)\n",
    "        if ext != \".json\":\n",
    "            # Stray files (.DS_Store, editor backups, ...) previously made re.match\n",
    "            # return None and crashed the cell with an AttributeError; skip them.\n",
    "            continue\n",
    "        file_path = os.path.join(dir_path, file_name)\n",
    "        # Context manager so the handle is closed as soon as the file has been parsed.\n",
    "        with open(file_path, \"r\", encoding = \"utf-8\") as data_file:\n",
    "            file_name2data[stem] = json.load(data_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# @specific\n",
    "# Choose the tokenizer and the token -> char-span mapper that match the configured encoder.\n",
    "encoder_type = config[\"encoder\"]\n",
    "if encoder_type == \"BERT\":\n",
    "    tokenizer = BertTokenizerFast.from_pretrained(config[\"bert_path\"], add_special_tokens = False, do_lower_case = False)\n",
    "    tokenize = tokenizer.tokenize\n",
    "\n",
    "    def get_tok2char_span_map(text):\n",
    "        \"\"\"Return the `offset_mapping` the BERT tokenizer reports for `text` (no special tokens).\"\"\"\n",
    "        return tokenizer.encode_plus(text, return_offsets_mapping = True, add_special_tokens = False)[\"offset_mapping\"]\n",
    "elif encoder_type == \"BiLSTM\":\n",
    "    def tokenize(text):\n",
    "        \"\"\"Split `text` on single spaces.\"\"\"\n",
    "        return text.split(\" \")\n",
    "\n",
    "    def get_tok2char_span_map(text):\n",
    "        \"\"\"Return one (start, end) character offset pair per space-separated token of `text`.\"\"\"\n",
    "        tok2char_span = []\n",
    "        char_num = 0\n",
    "        for tok in tokenize(text):\n",
    "            tok2char_span.append((char_num, char_num + len(tok)))\n",
    "            char_num += len(tok) + 1 # +1: whitespace\n",
    "        return tok2char_span\n",
    "else:\n",
    "    # Fail here rather than with a NameError (or a stale tokenizer from a previous\n",
    "    # kernel run) in a later cell.\n",
    "    raise ValueError(\"Unsupported encoder {!r}: expected 'BERT' or 'BiLSTM'\".format(encoder_type))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the encoder-specific callables chosen above to the shared Preprocessor helper.\n",
    "preprocessor = Preprocessor(\n",
    "    tokenize_func = tokenize,\n",
    "    get_tok2char_span_map_func = get_tok2char_span_map,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "## Transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert every loaded file from its original annotation format, unless it already is \"tplinker\".\n",
    "ori_format = config[\"ori_data_format\"]\n",
    "if ori_format != \"tplinker\": # if tplinker, skip transforming\n",
    "    for file_name, data in file_name2data.items():\n",
    "        # Infer the split from the file name. The precedence (test > valid > train)\n",
    "        # reproduces the former chain of independent `if`s, where the last match won.\n",
    "        if \"test\" in file_name:\n",
    "            data_type = \"test\"\n",
    "        elif \"valid\" in file_name:\n",
    "            data_type = \"valid\"\n",
    "        elif \"train\" in file_name:\n",
    "            data_type = \"train\"\n",
    "        else:\n",
    "            # Previously this case either raised NameError (first file) or silently\n",
    "            # reused the split of the preceding file.\n",
    "            raise ValueError(\"Cannot infer dataset type of {!r}: file name must contain 'train', 'valid' or 'test'\".format(file_name))\n",
    "        file_name2data[file_name] = preprocessor.transform_data(data, ori_format = ori_format, dataset_type = data_type, add_id = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "## Clean and Add Spans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check token level span\n",
    "def check_tok_span(data):\n",
    "    \"\"\"Check that every token-level span decodes back to its annotated text.\n",
    "\n",
    "    Each sample's text is mapped to char offsets with the notebook-level\n",
    "    `get_tok2char_span_map`; every token span is turned into a substring of the\n",
    "    text and compared with the gold entity / subject / object string.\n",
    "\n",
    "    Args:\n",
    "        data: iterable of dict samples. The keys read are `text`, `entity_list`\n",
    "            (items with `tok_span`, `text`) and `relation_list` (items with\n",
    "            `subj_tok_span`, `obj_tok_span`, `subject`, `object`).\n",
    "\n",
    "    Returns:\n",
    "        set of str: one description per distinct mismatch; empty if all spans agree.\n",
    "    \"\"\"\n",
    "    def extr_ent(text, tok_span, tok2char_span):\n",
    "        char_span_list = tok2char_span[tok_span[0]:tok_span[1]]\n",
    "        if not char_span_list:\n",
    "            # Empty or out-of-range token span: nothing to decode. Returning \"\" lets\n",
    "            # the caller record a mismatch instead of failing with an IndexError.\n",
    "            return \"\"\n",
    "        return text[char_span_list[0][0]:char_span_list[-1][1]]\n",
    "\n",
    "    span_error_memory = set()\n",
    "    for sample in tqdm(data, desc = \"check tok spans\"):\n",
    "        text = sample[\"text\"]\n",
    "        tok2char_span = get_tok2char_span_map(text)\n",
    "        for ent in sample[\"entity_list\"]:\n",
    "            # Decode once and reuse the result for both the comparison and the message.\n",
    "            decoded_ent = extr_ent(text, ent[\"tok_span\"], tok2char_span)\n",
    "            if decoded_ent != ent[\"text\"]:\n",
    "                span_error_memory.add(\"extr ent: {}---gold ent: {}\".format(decoded_ent, ent[\"text\"]))\n",
    "\n",
    "        for rel in sample[\"relation_list\"]:\n",
    "            arguments = (\n",
    "                (rel[\"subj_tok_span\"], rel[\"subject\"]),\n",
    "                (rel[\"obj_tok_span\"], rel[\"object\"]),\n",
    "            )\n",
    "            for tok_span, gold_text in arguments:\n",
    "                decoded_arg = extr_ent(text, tok_span, tok2char_span)\n",
    "                if decoded_arg != gold_text:\n",
    "                    span_error_memory.add(\"extr: {}---gold: {}\".format(decoded_arg, gold_text))\n",
    "\n",
    "    return span_error_memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "clean data: 100%|██████████| 171227/171227 [00:01<00:00, 89466.07it/s]\n",
      "adding char level spans: 100%|██████████| 171227/171227 [00:27<00:00, 6232.56it/s]\n",
      "building relation type set and entity type set: 100%|██████████| 171227/171227 [00:00<00:00, 344443.98it/s]\n",
      "adding token level spans: 100%|██████████| 171227/171227 [00:33<00:00, 5181.82it/s]\n",
      "check tok spans: 100%|██████████| 171227/171227 [00:30<00:00, 5638.47it/s]\n",
      "clean data:  35%|███▌      | 7317/20665 [00:00<00:00, 73167.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'extr ent: Working Title Films2009---gold ent: Working Title Films', 'extr: 5New Balance---gold: New Balance', 'extr: Stand By Me09---gold: Stand By Me', 'extr ent: 1TCI---gold ent: TCI', 'extr: Jaime Olías---gold: Jaime Ol', 'extr: BBC1---gold: BBC', 'extr ent: Marc Angélil---gold ent: Marc Ang', 'extr ent: 银座ブルースワン---gold ent: 银座ブルース', 'extr: ⑴David .Chin---gold: David .Chin', 'extr ent: New BalanceNew---gold ent: New Balance', 'extr ent: Becky43---gold ent: Becky', 'extr: ★TF少年GO---gold: TF少年GO', 'extr: 粤情粤迷-不老情歌ⅢLRC---gold: 粤情粤迷-不老情歌Ⅲ', 'extr ent: 3M3M公司---gold ent: 3M公司', 'extr: IIGA---gold: IGA', 'extr: CCTV3---gold: CCTV', 'extr ent: TVB83---gold ent: TVB', 'extr ent: 2013MBC演艺大赏年度STAR奖---gold ent: MBC演艺大赏年度STAR奖', 'extr: SKE48の---gold: SKE48', 'extr ent: 1998TVB---gold ent: TVB', 'extr ent: NBCNBC---gold ent: NBC', 'extr: Twins2006---gold: Twins', 'extr: ☆★VicSHINee---gold: VicSHINee', 'extr ent: 绝世好BRA---gold ent: 绝世好B', 'extr: CTVSatelliteCommunicationsInc---gold: CTVS', 'extr ent: 粤情粤迷-不老情歌ⅢLRC---gold ent: 粤情粤迷-不老情歌Ⅲ', 'extr ent: 1CN5A科技广场---gold ent: CN5A科技广场', 'extr: Beyond1992---gold: Beyond', 'extr ent: SKE48に---gold ent: SKE48', 'extr: STemWin---gold: ST', 'extr ent: SINOInternational---gold ent: SINO', 'extr: Marc Angélil---gold: Marc Ang', 'extr ent: さらばGJ部---gold ent: GJ部', 'extr: 银座ブルースワン---gold: 银座ブルース', 'extr ent: Netflix2017---gold ent: Netflix', 'extr: Shine5---gold: Shine', 'extr: ～Harusource---gold: Haru', 'extr ent: CTVSatelliteCommunicationsInc---gold ent: CTVS', 'extr: TOTOTOTO---gold: TOTO', 'extr ent: 30ESET---gold ent: ESET', 'extr ent: 冒险岛2💜---gold ent: 冒险岛2', 'extr ent: Nicole21---gold ent: Nicole', 'extr: BalanceNew Balance---gold: New Balance', 'extr: GCBILL---gold: GC', 'extr ent: ←══╬∝←═7月1日---gold ent: 7月1日', 'extr: Becky43---gold: Becky', 'extr ent: MIKEVGNODN---gold ent: VGNODN', 'extr: SKE48に---gold: SKE48', 'extr: LOVO♡---gold: LOVO', 'extr: 30ESET---gold: ESET', 'extr: ～Tumblr---gold: Tumblr', 'extr ent: 
JerryC09---gold ent: JerryC', 'extr: NBCNBC---gold: NBC', 'extr ent: CCAA002---gold ent: CCAA', 'extr ent: Promise2008---gold ent: Promise', 'extr: Aesha Mohammadzai1---gold: Aesha Mohammadzai', 'extr: 2014SBS演艺大赏最佳MC奖---gold: SBS演艺大赏最佳MC奖', 'extr ent: SHES.H.E---gold ent: S.H.E', 'extr: Johnny Foote6---gold: Johnny Foote', 'extr ent: GCBILL---gold ent: GC', 'extr ent: BBC1---gold ent: BBC', 'extr ent: Stand By Me09---gold ent: Stand By Me', 'extr: 2011MTV封神榜音乐奖十大人气王---gold: MTV封神榜音乐奖十大人气王', 'extr ent: 日本NHK2006---gold ent: 日本NHK', 'extr ent: ⑴David .Chin---gold ent: David .Chin', 'extr ent: IECTR24722∶2007---gold ent: IEC', 'extr ent: HAHAMBC---gold ent: HAHA', 'extr ent: TVB1989---gold ent: TVB', 'extr: Roberto Gómez---gold: Roberto G', 'extr: P.A.WORKS×---gold: P.A.WORKS', 'extr ent: ☆★VicSHINee---gold ent: VicSHINee', 'extr ent: 郷田ほづみのふぁぼられ---gold ent: 郷田ほづみ', 'extr ent: QUTBBS---gold ent: QUT', 'extr ent: TOTOTOTO---gold ent: TOTO', 'extr ent: 8My heart will go on---gold ent: My heart will go on', 'extr ent: CNS1952---gold ent: CNS', 'extr: TVB1989---gold: TVB', 'extr: 1996TVB---gold: TVB', 'extr ent: IIGA---gold ent: IGA', 'extr: 冒险岛2💜---gold: 冒险岛2', 'extr: CNS1952---gold: CNS', 'extr: Tank2007---gold: Tank', 'extr ent: 魔力TV1---gold ent: 魔力TV', 'extr: New BalanceNew---gold: New Balance', 'extr: 日本NHK2006---gold: 日本NHK', 'extr ent: 2014Hito流行音乐奖高音质Hito最潜力女声奖---gold ent: Hito流行音乐奖高音质Hito最潜力女声奖', 'extr: YumiMiko---gold: Miko', 'extr: BIH1984---gold: BIH', 'extr ent: SKE48の---gold ent: SKE48', 'extr: Netflix2017---gold: Netflix', 'extr: 1TCI---gold: TCI', 'extr: CCTV12---gold: CCTV', 'extr: ISO9000---gold: ISO', 'extr: Diamond25---gold: Diamond', 'extr ent: 48SKE48アゲぽよボウル48---gold ent: SKE48', 'extr ent: CCTV3---gold ent: CCTV', 'extr: 💜VALENTINO中国---gold: VALENTINO中国', 'extr ent: FMS.H.E---gold ent: S.H.E', 'extr: 1998TVB---gold: TVB', 'extr ent: IGIIGI---gold ent: IGI', 'extr: IECTR24722∶2007---gold: IEC', 'extr: ★Dangerous Liaisons---gold: Dangerous Liaisons', 
'extr: 特殊案件专案组TEN2---gold: 特殊案件专案组TEN', 'extr: Promise2008---gold: Promise', 'extr: 暴力街区13---gold: 暴力街区1', 'extr: 😱😱😱TATA木门---gold: TATA木门', 'extr ent: The Beatles Code2---gold ent: The Beatles Code', 'extr ent: 2SONOS---gold ent: SONOS', 'extr: KBS2---gold: KBS', 'extr ent: SKEで---gold ent: SKE', 'extr: QUTBBS---gold: QUT', 'extr ent: JHRobot---gold ent: JHR', 'extr: Working Title Films2009---gold: Working Title Films', 'extr: 01Heroes---gold: Heroes', 'extr: CCAA002---gold: CCAA', 'extr: Daniel Cebrián---gold: Daniel Cebri', 'extr: SHES.H.E---gold: S.H.E', 'extr: TVB83---gold: TVB', 'extr ent: TVB28---gold ent: TVB', 'extr: JHRobot---gold: JHR', 'extr ent: Jaime Olías---gold ent: Jaime Ol', 'extr ent: ～Tumblr---gold ent: Tumblr', 'extr ent: 2～SPEC---gold ent: SPEC', 'extr ent: 1996TVB---gold ent: TVB', 'extr ent: LOVO♡---gold ent: LOVO', 'extr ent: 暴力街区13---gold ent: 暴力街区1', 'extr: BBC95---gold: BBC', 'extr: IEC60745---gold: IEC', 'extr: TVB1999---gold: TVB', 'extr ent: KBS2TV---gold ent: KBS', 'extr ent: BalanceNew Balance---gold ent: New Balance', 'extr ent: IEC118---gold ent: IEC', 'extr ent: Twins2006---gold ent: Twins', 'extr: 詹姆斯卡梅隆\\xad主演 莱昂纳多迪卡普里奥 凯特温丝莱特\\xad剧情 不用我多说吧\\xad个人点评\\xad《泰坦尼克号》莱昂纳多的成名作 一部灾难电影 却让人完全沉浸在男女主人公的爱情里 一个愿意为爱付出生命 一个可以为爱人放弃所有富贵 俩人不顾世俗的看法选择相爱 当你搂着喜欢的人一起看完这个电影后会不会有些许的感悟---gold: 詹姆斯卡梅隆\\xad', 'extr ent: 😱😱😱TATA木门---gold ent: TATA木门', 'extr ent: いだろうさいはての湖---gold ent: さいはての湖', 'extr: ＝＝＝＝＝＝＝ROCKY君---gold: ROCKY君', 'extr ent: YumiMiko---gold ent: Miko', 'extr: 48SKE48アゲぽよボウル48---gold: SKE48', 'extr: Wands1994---gold: Wands', 'extr ent: S.H.E2008---gold ent: S.H.E', 'extr: Nicole21---gold: Nicole', 'extr ent: 特殊案件专案组TEN2---gold ent: 特殊案件专案组TEN', 'extr: Rafael Tolói---gold: Rafael Tol', 'extr ent: Prada2013---gold ent: Prada', 'extr ent: ★Dangerous Liaisons---gold ent: Dangerous Liaisons', 'extr: SINOInternational---gold: SINO', 'extr ent: STemWin---gold ent: ST', 'extr ent: 北国は寒いだろうさいはての---gold ent: 北国は寒いだろう', 'extr: The Beatles Code2---gold: 
The Beatles Code', 'extr: Beyond1993---gold: Beyond', 'extr: ←══╬∝←═7月1日---gold: 7月1日', 'extr: SPEC～---gold: SPEC', 'extr ent: Beyond4---gold ent: Beyond', 'extr ent: YumiMiko---gold ent: Yumi', 'extr: 2013MBC演艺大赏年度STAR奖---gold: MBC演艺大赏年度STAR奖', 'extr ent: Rafael Tolói---gold ent: Rafael Tol', 'extr: KBS2TV---gold: KBS', 'extr: EddieBrock3---gold: EddieBrock', 'extr ent: 01Heroes---gold ent: Heroes', 'extr ent: AlphabetInc---gold ent: Alphabet', 'extr: 30MBC演技大赏 迷你剧男子优秀演技奖---gold: MBC演技大赏 迷你剧男子优秀演技奖', 'extr ent: Roberto Gómez---gold ent: Roberto G', 'extr ent: BIH1984---gold ent: BIH', 'extr: 7BBC---gold: BBC', 'extr: AlphabetInc---gold: Alphabet', 'extr ent: Crystal✎---gold ent: Crystal', 'extr: 2014Hito流行音乐奖高音质Hito最潜力女声奖---gold: Hito流行音乐奖高音质Hito最潜力女声奖', 'extr ent: IEC60745---gold ent: IEC', 'extr ent: ISO9000---gold ent: ISO', 'extr ent: ASTMF2085---gold ent: ASTM', 'extr: Prada2013---gold: Prada', 'extr: 绝世好BRA---gold: 绝世好B', 'extr: 2SONOS---gold: SONOS', 'extr: IGIIGI---gold: IGI', 'extr ent: Tank2007---gold ent: Tank', 'extr: 3M3M公司---gold: 3M公司', 'extr ent: Johnny Foote6---gold ent: Johnny Foote', 'extr: HAHAMBC---gold: HAHA', 'extr: 8My heart will go on---gold: My heart will go on', 'extr: Crystal★---gold: Crystal', 'extr ent: Daniel Cebrián---gold ent: Daniel Cebri', 'extr ent: 30MBC演技大赏 迷你剧男子优秀演技奖---gold ent: MBC演技大赏 迷你剧男子优秀演技奖', 'extr: ●To lose in amber---gold: To lose in amber', 'extr ent: 2014SBS演艺大赏最佳MC奖---gold ent: SBS演艺大赏最佳MC奖', 'extr: MojangAB---gold: Mojang', 'extr: さらばGJ部---gold: GJ部', 'extr: ASTMF2085---gold: ASTM', 'extr: TVB28---gold: TVB', 'extr ent: ＝＝＝＝＝＝＝ROCKY君---gold ent: ROCKY君', 'extr ent: CCTV12---gold ent: CCTV', 'extr ent: Robert Bridge5---gold ent: Robert Bridge', 'extr: Robert Bridge5---gold: Robert Bridge', 'extr: ☛7月27日---gold: 7月27日', 'extr ent: ●To lose in amber---gold ent: To lose in amber', 'extr: 2006KBS明星发掘大赛冠军---gold: KBS明星发掘大赛冠军', 'extr ent: CCTV2---gold ent: CCTV', 'extr: Beyond4---gold: Beyond', 'extr: いだろうさいはての湖---gold: 
さいはての湖', 'extr ent: 2011MTV封神榜音乐奖十大人气王---gold ent: MTV封神榜音乐奖十大人气王', 'extr: JerryC02---gold: JerryC', 'extr: ExcelHome---gold: Excel', 'extr: AB型君---gold: B型君', 'extr: YumiMiko---gold: Yumi', 'extr: 郷田ほづみのふぁぼられ---gold: 郷田ほづみ', 'extr ent: 💜VALENTINO中国---gold ent: VALENTINO中国', 'extr ent: TVB1999---gold ent: TVB', 'extr ent: ◆GUCCI---gold ent: GUCCI', 'extr: MIKEVGNODN---gold: VGNODN', 'extr ent: AB型君---gold ent: B型君', 'extr ent: B1A4---gold ent: B1A', 'extr ent: ★TF少年GO---gold ent: TF少年GO', 'extr ent: BBC95---gold ent: BBC', 'extr ent: Shine5---gold ent: Shine', 'extr ent: ～Harusource---gold ent: Haru', 'extr: Jpod---gold: J', 'extr: S.H.E2008---gold: S.H.E', 'extr: ◆GUCCI---gold: GUCCI', 'extr: CCTV2---gold: CCTV', 'extr ent: Wands1994---gold ent: Wands', 'extr ent: JerryC02---gold ent: JerryC', 'extr: 1CN5A科技广场---gold: CN5A科技广场', 'extr ent: KBS2---gold ent: KBS', 'extr ent: ☛7月27日---gold ent: 7月27日', 'extr ent: 2006KBS明星发掘大赛冠军---gold ent: KBS明星发掘大赛冠军', 'extr: JerryC09---gold: JerryC', 'extr: 2007SBS演技大赏连续剧最佳配角奖---gold: SBS演技大赏连续剧最佳配角奖', 'extr ent: Crystal★---gold ent: Crystal', 'extr: FMS.H.E---gold: S.H.E', 'extr: 2～SPEC---gold: SPEC', 'extr ent: 5New Balance---gold ent: New Balance', 'extr: 魔力TV1---gold: 魔力TV', 'extr ent: Jpod---gold ent: J', 'extr ent: Beyond1992---gold ent: Beyond', 'extr: SKEで---gold: SKE', 'extr ent: Beyond1993---gold ent: Beyond', 'extr: 北国は寒いだろうさいはての---gold: 北国は寒いだろう', 'extr ent: P.A.WORKS×---gold ent: P.A.WORKS', 'extr ent: EddieBrock3---gold ent: EddieBrock', 'extr ent: Diamond25---gold ent: Diamond', 'extr: Crystal✎---gold: Crystal', 'extr ent: Aesha Mohammadzai1---gold ent: Aesha Mohammadzai', 'extr ent: SPEC～---gold ent: SPEC', 'extr ent: MojangAB---gold ent: Mojang', 'extr: B1A4---gold: B1A', 'extr ent: 詹姆斯卡梅隆\\xad主演 莱昂纳多迪卡普里奥 凯特温丝莱特\\xad剧情 不用我多说吧\\xad个人点评\\xad《泰坦尼克号》莱昂纳多的成名作 一部灾难电影 却让人完全沉浸在男女主人公的爱情里 一个愿意为爱付出生命 一个可以为爱人放弃所有富贵 俩人不顾世俗的看法选择相爱 当你搂着喜欢的人一起看完这个电影后会不会有些许的感悟---gold ent: 詹姆斯卡梅隆\\xad', 'extr ent: 2007SBS演技大赏连续剧最佳配角奖---gold ent: 
SBS演技大赏连续剧最佳配角奖', 'extr: IEC118---gold: IEC', 'extr ent: 7BBC---gold ent: BBC', 'extr ent: ExcelHome---gold ent: Excel'}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "clean data: 100%|██████████| 20665/20665 [00:00<00:00, 71854.51it/s]\n",
      "adding char level spans: 100%|██████████| 20665/20665 [00:03<00:00, 6229.23it/s]\n",
      "building relation type set and entity type set: 100%|██████████| 20665/20665 [00:00<00:00, 286857.05it/s]\n",
      "adding token level spans: 100%|██████████| 20665/20665 [00:04<00:00, 4507.92it/s]\n",
      "check tok spans: 100%|██████████| 20665/20665 [00:03<00:00, 5659.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'extr: 泠泠七弦ZLH2---gold: 泠泠七弦ZLH', 'extr: ∶KRAD---gold: KRAD', 'extr ent: MH&M---gold ent: H&M', 'extr ent: H&MH---gold ent: H&M', 'extr ent: PuRiPaRa2014年7月5日---gold ent: 2014年7月5日', 'extr ent: S.H.E39---gold ent: S.H.E', 'extr ent: Ⅱ2014年7月5日---gold ent: 2014年7月5日', 'extr ent: 泠泠七弦ZLH2---gold ent: 泠泠七弦ZLH', 'extr: RevivePuzzle---gold: Revive', 'extr ent: Jasper2---gold ent: Jasper', 'extr: The Song Sisters---gold: The Song Sister', 'extr ent: NetflixNetflix---gold ent: Netflix', 'extr: PCTCGP---gold: PCTCG', 'extr ent: CCTV5NBA最前线---gold ent: NBA最前线', 'extr ent: Crystal2014年7月5日---gold ent: 2014年7月5日', 'extr ent: RevivePuzzle---gold ent: Revive', 'extr ent: NicoNico---gold ent: Nico', 'extr: MH&M---gold: H&M', 'extr ent: PCTCGP---gold ent: PCTCG', 'extr ent: PopeThanawat---gold ent: Pope', 'extr ent: Tanyaraes---gold ent: Tanya', 'extr: NicoNico---gold: Nico', 'extr ent: SNH48▪---gold ent: SNH48', 'extr: SNH48▪---gold: SNH48', 'extr: NetflixNetflix---gold: Netflix', 'extr: CCTV5NBA最前线---gold: NBA最前线', 'extr ent: 美少女战士Crystal2014---gold ent: 美少女战士Crystal', 'extr: TVB1996---gold: TVB', 'extr ent: PGI®---gold ent: PGI', 'extr: PuRiPaRa2014年7月5日---gold: 2014年7月5日', 'extr: PGI®---gold: PGI', 'extr ent: ChokChokChok---gold ent: Chok', 'extr ent: TVB1998---gold ent: TVB', 'extr: Jasper2---gold: Jasper', \"extr: Vin'Selection1---gold: Vin'Selection\", 'extr ent: TVB1976年---gold ent: 1976年', 'extr: PopeThanawat---gold: Pope', 'extr: Crystal2014年7月5日---gold: 2014年7月5日', 'extr: TVB1976年---gold: 1976年', 'extr: H&MH---gold: H&M', 'extr ent: TVB1996---gold ent: TVB', 'extr ent: The Song Sisters---gold ent: The Song Sister', 'extr: 美少女战士Crystal2014---gold: 美少女战士Crystal', 'extr ent: ∶KRAD---gold ent: KRAD', 'extr: Tanyaraes---gold: Tanya', 'extr: TVB1998---gold: TVB', \"extr ent: Vin'Selection1---gold ent: Vin'Selection\", 'extr: ChokChokChok---gold: Chok', 'extr: S.H.E39---gold: S.H.E', 'extr: Ⅱ2014年7月5日---gold: 2014年7月5日'}\n",
      "{'train_data': {'miss_samples': 282, 'tok_span_error': 250},\n",
      " 'valid_data': {'miss_samples': 33, 'tok_span_error': 50}}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Per-file pipeline for annotated data: clean -> (optional) char spans -> entity/relation\n",
    "# type collection -> token spans -> (optional) token-span check.\n",
    "# Outputs of this cell: rel_set, ent_set (type vocabularies), error_statistics (per-file\n",
    "# counters) and the updated entries of file_name2data.\n",
    "# clean, add char span, tok span\n",
    "# collect relations\n",
    "# check tok spans\n",
    "rel_set = set()\n",
    "ent_set = set()\n",
    "error_statistics = {}\n",
    "for file_name, data in file_name2data.items():\n",
    "    assert len(data) > 0\n",
    "    # Only files whose first sample carries a \"relation_list\" are processed; any other\n",
    "    # file is left in file_name2data exactly as it was loaded/transformed.\n",
    "    if \"relation_list\" in data[0]: # train or valid data\n",
    "        # rm redundant whitespaces\n",
    "        # separate by whitespaces\n",
    "        # NOTE(review): the two notes above describe what clean_data_wo_span is expected to do;\n",
    "        # its implementation lives in common.utils.Preprocessor -- confirm there.\n",
    "        data = preprocessor.clean_data_wo_span(data, separate = config[\"separate_char_by_white\"])\n",
    "        error_statistics[file_name] = {}\n",
    "#         if file_name != \"train_data\":\n",
    "#             set_trace()\n",
    "        # add char span\n",
    "        # add_char_span returns the data plus a list of samples it reports as missed;\n",
    "        # only the length of that list is kept.\n",
    "        if config[\"add_char_span\"]:\n",
    "            data, miss_sample_list = preprocessor.add_char_span(data, config[\"ignore_subword\"])\n",
    "            error_statistics[file_name][\"miss_samples\"] = len(miss_sample_list)\n",
    "            \n",
    "#         # clean\n",
    "#         data, bad_samples_w_char_span_error = preprocessor.clean_data_w_span(data)\n",
    "#         error_statistics[file_name][\"char_span_error\"] = len(bad_samples_w_char_span_error)\n",
    "                            \n",
    "        # collect relation types and entity types\n",
    "        for sample in tqdm(data, desc = \"building relation type set and entity type set\"):\n",
    "            # The fallback below reads rel[\"subj_char_span\"] / rel[\"obj_char_span\"], so the\n",
    "            # relations must already carry char spans at this point.\n",
    "            # NOTE(review): presumably those keys come from add_char_span above or from the\n",
    "            # input data itself when config[\"add_char_span\"] is false -- confirm.\n",
    "            if \"entity_list\" not in sample: # if \"entity_list\" not in sample, generate entity list with default type\n",
    "                ent_list = []\n",
    "                for rel in sample[\"relation_list\"]:\n",
    "                    ent_list.append({\n",
    "                        \"text\": rel[\"subject\"],\n",
    "                        \"type\": \"DEFAULT\",\n",
    "                        \"char_span\": rel[\"subj_char_span\"],\n",
    "                    })\n",
    "                    ent_list.append({\n",
    "                        \"text\": rel[\"object\"],\n",
    "                        \"type\": \"DEFAULT\",\n",
    "                        \"char_span\": rel[\"obj_char_span\"],\n",
    "                    })\n",
    "                sample[\"entity_list\"] = ent_list\n",
    "            \n",
    "            for ent in sample[\"entity_list\"]:\n",
    "                ent_set.add(ent[\"type\"])\n",
    "                \n",
    "            for rel in sample[\"relation_list\"]:\n",
    "                rel_set.add(rel[\"predicate\"])\n",
    "               \n",
    "        # add tok span\n",
    "        # Runs after entity_list has been filled in above, so generated entities are in the\n",
    "        # data passed to add_tok_span.\n",
    "        data = preprocessor.add_tok_span(data)\n",
    "\n",
    "        # check tok span\n",
    "        # Prints the full set of mismatch descriptions (if any) and records how many there are.\n",
    "        if config[\"check_tok_span\"]:\n",
    "            span_error_memory = check_tok_span(data)\n",
    "            if len(span_error_memory) > 0:\n",
    "                print(span_error_memory)\n",
    "            error_statistics[file_name][\"tok_span_error\"] = len(span_error_memory)\n",
    "            \n",
    "        file_name2data[file_name] = data\n",
    "pprint(error_statistics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Output to Disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:train_data is output to ../data4bert/duie2/train_data.json\n",
      "INFO:root:test_data is output to ../data4bert/duie2/test_data.json\n",
      "INFO:root:valid_data is output to ../data4bert/duie2/valid_data.json\n",
      "INFO:root:rel2id is output to ../data4bert/duie2/rel2id.json\n",
      "INFO:root:ent2id is output to ../data4bert/duie2/ent2id.json\n",
      "INFO:root:data_statistics is output to ../data4bert/duie2/data_statistics.txt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'entity_type_num': 26,\n",
      " 'relation_type_num': 52,\n",
      " 'test_data': 101311,\n",
      " 'train_data': 171227,\n",
      " 'valid_data': 20665}\n"
     ]
    }
   ],
   "source": [
    "def _dump_json(obj, path, **dump_kwargs):\n",
    "    \"\"\"Write `obj` to `path` as UTF-8 JSON (non-ASCII kept as-is) and close the file.\n",
    "\n",
    "    Extra keyword arguments (e.g. `indent`) are forwarded to `json.dump`.\n",
    "    \"\"\"\n",
    "    with open(path, \"w\", encoding = \"utf-8\") as out_file:\n",
    "        json.dump(obj, out_file, ensure_ascii = False, **dump_kwargs)\n",
    "\n",
    "# Sort the type names so the ids are stable across runs.\n",
    "rel_set = sorted(rel_set)\n",
    "rel2id = {rel:ind for ind, rel in enumerate(rel_set)}\n",
    "\n",
    "ent_set = sorted(ent_set)\n",
    "ent2id = {ent:ind for ind, ent in enumerate(ent_set)}\n",
    "\n",
    "data_statistics = {\n",
    "    \"relation_type_num\": len(rel2id),\n",
    "    \"entity_type_num\": len(ent2id),\n",
    "}\n",
    "\n",
    "# One output file per input file, plus the number of samples in each.\n",
    "for file_name, data in file_name2data.items():\n",
    "    data_path = os.path.join(data_out_dir, \"{}.json\".format(file_name))\n",
    "    _dump_json(data, data_path)\n",
    "    logging.info(\"{} is output to {}\".format(file_name, data_path))\n",
    "    data_statistics[file_name] = len(data)\n",
    "\n",
    "rel2id_path = os.path.join(data_out_dir, \"rel2id.json\")\n",
    "_dump_json(rel2id, rel2id_path)\n",
    "logging.info(\"rel2id is output to {}\".format(rel2id_path))\n",
    "\n",
    "ent2id_path = os.path.join(data_out_dir, \"ent2id.json\")\n",
    "_dump_json(ent2id, ent2id_path)\n",
    "logging.info(\"ent2id is output to {}\".format(ent2id_path))\n",
    "\n",
    "data_statistics_path = os.path.join(data_out_dir, \"data_statistics.txt\")\n",
    "_dump_json(data_statistics, data_statistics_path, indent = 4)\n",
    "logging.info(\"data_statistics is output to {}\".format(data_statistics_path))\n",
    "\n",
    "pprint(data_statistics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate WordDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Build a token -> index vocabulary; only needed for encoders without their own tokenizer vocab.\n",
    "if config[\"encoder\"] in {\"BiLSTM\", }:\n",
    "    # Pool the samples of every loaded file.\n",
    "    all_data = []\n",
    "    for data in file_name2data.values():\n",
    "        all_data.extend(data)\n",
    "\n",
    "    # Count how often each token occurs.\n",
    "    token2num = {}\n",
    "    for sample in tqdm(all_data, desc = \"Tokenizing\"):\n",
    "        text = sample['text']\n",
    "        for tok in tokenize(text):\n",
    "            token2num[tok] = token2num.get(tok, 0) + 1\n",
    "\n",
    "    # Most frequent first, so the size cap below keeps the most common tokens.\n",
    "    token2num = dict(sorted(token2num.items(), key = lambda x: x[1], reverse = True))\n",
    "    max_token_num = 50000\n",
    "    min_token_freq = 3 # filter words with a frequency of less than 3\n",
    "    token_set = set()\n",
    "    for tok, num in tqdm(token2num.items(), desc = \"Filter uncommon words\"):\n",
    "        if num < min_token_freq:\n",
    "            # Counts are in descending order: every remaining token is too rare as well.\n",
    "            break\n",
    "        token_set.add(tok)\n",
    "        if len(token_set) == max_token_num:\n",
    "            break\n",
    "\n",
    "    # Indices 0 and 1 are reserved for padding and unknown tokens.\n",
    "    token2idx = {tok:idx + 2 for idx, tok in enumerate(sorted(token_set))}\n",
    "    token2idx[\"<PAD>\"] = 0\n",
    "    token2idx[\"<UNK>\"] = 1\n",
    "#     idx2token = {idx:tok for tok, idx in token2idx.items()}\n",
    "\n",
    "    dict_path = os.path.join(data_out_dir, \"token2idx.json\")\n",
    "    # Context manager so the vocabulary file is flushed and closed before it is reported as written.\n",
    "    with open(dict_path, \"w\", encoding = \"utf-8\") as dict_file:\n",
    "        json.dump(token2idx, dict_file, ensure_ascii = False, indent = 4)\n",
    "    logging.info(\"token2idx is output to {}, total token num: {}\".format(dict_path, len(token2idx)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
